#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
ChaCha20-Poly1305 AEAD 加密解密实现
支持12字节nonce和指定counter
提供认证加密功能，包含完整性保护和认证
"""

import hmac
import struct


class ChaCha20Poly1305:
    """ChaCha20-Poly1305 AEAD (RFC 8439) in pure Python.

    Keys are 32 bytes, nonces 12 bytes, tags 16 bytes. For every message,
    ChaCha20 block 0 supplies the one-time Poly1305 key and the payload is
    XORed with the keystream starting at block counter 1.

    ``self.counter`` is kept for backward compatibility. After a successful
    :meth:`encrypt` / :meth:`decrypt` it holds the block counter that follows
    the last keystream block used. It is never read as an input: each call
    starts again from counter 1, and the keystream never depends on it.
    """

    # ChaCha20 state constants: "expand 32-byte k" as little-endian words.
    _CONSTANTS = (0x61707865, 0x3320646e, 0x79622d32, 0x6b206574)
    # One double round: four column quarter-rounds, then four diagonal ones.
    _DOUBLE_ROUND = (
        (0, 4, 8, 12), (1, 5, 9, 13), (2, 6, 10, 14), (3, 7, 11, 15),
        (0, 5, 10, 15), (1, 6, 11, 12), (2, 7, 8, 13), (3, 4, 9, 14),
    )
    # Keystream bytes produced by one ChaCha20 block.
    _BLOCK_SIZE = 64
    # Largest value the 32-bit block counter can hold.
    _MAX_COUNTER = 0xffffffff
    # The Poly1305 prime, 2^130 - 5.
    _POLY1305_P = (1 << 130) - 5

    def __init__(self):
        """Create an AEAD instance; ``counter`` starts at 0."""
        self.counter = 0

    def _quarter_round(self, a: int, b: int, c: int, d: int) -> tuple:
        """Apply the ChaCha20 quarter round to four 32-bit words.

        :return: the updated ``(a, b, c, d)`` tuple
        """
        a = (a + b) & 0xffffffff
        d ^= a
        d = ((d << 16) | (d >> 16)) & 0xffffffff

        c = (c + d) & 0xffffffff
        b ^= c
        b = ((b << 12) | (b >> 20)) & 0xffffffff

        a = (a + b) & 0xffffffff
        d ^= a
        d = ((d << 8) | (d >> 24)) & 0xffffffff

        c = (c + d) & 0xffffffff
        b ^= c
        b = ((b << 7) | (b >> 25)) & 0xffffffff

        return a, b, c, d

    def _chacha20_block(self, key: bytes, nonce: bytes, counter: int) -> bytes:
        """Return the 64-byte ChaCha20 keystream block for ``counter``.

        :param key: 32-byte key
        :param nonce: 12-byte nonce
        :param counter: block counter, ``0 <= counter <= 2**32 - 1``
        :return: 64 bytes of keystream
        :raises ValueError: if ``counter`` is outside the 32-bit range
            (such a value would otherwise silently yield a wrong block)
        """
        if not 0 <= counter <= self._MAX_COUNTER:
            raise ValueError("counter必须在0到2^32-1之间")

        # State layout: 4 constants | 8 key words | counter | 3 nonce words.
        state = (list(self._CONSTANTS)
                 + list(struct.unpack('<8I', key))
                 + [counter]
                 + list(struct.unpack('<3I', nonce)))

        working = state[:]
        for _ in range(10):  # 10 double rounds = 20 rounds
            for ia, ib, ic, id_ in self._DOUBLE_ROUND:
                (working[ia], working[ib],
                 working[ic], working[id_]) = self._quarter_round(
                    working[ia], working[ib], working[ic], working[id_])

        # Feed-forward: add the input state to the permuted state.
        return struct.pack(
            '<16I',
            *((w + s) & 0xffffffff for w, s in zip(working, state)))

    def _chacha20_xor(self, key: bytes, nonce: bytes, data: bytes,
                      initial_counter: int = 1) -> tuple:
        """XOR ``data`` with the ChaCha20 keystream starting at ``initial_counter``.

        Encryption and decryption are the same operation. A local counter is
        used, so the result never depends on ``self.counter``.

        :return: ``(output, next_counter)``: the XORed bytes, and the counter
            value following the last keystream block consumed
        :raises ValueError: if ``data`` needs more blocks than the 32-bit
            counter can address (2**38 - 64 bytes when starting at 1)
        """
        length = len(data)
        block_count = -(-length // self._BLOCK_SIZE)  # ceil(length / 64)
        if initial_counter + block_count - 1 > self._MAX_COUNTER:
            raise ValueError("数据过长：超出ChaCha20计数器范围")
        # Build the keystream once with join instead of quadratic +=.
        keystream = b''.join(
            self._chacha20_block(key, nonce, initial_counter + i)
            for i in range(block_count))
        # XOR as two big integers: much faster than a per-byte generator.
        mixed = (int.from_bytes(data, 'little')
                 ^ int.from_bytes(keystream[:length], 'little'))
        return mixed.to_bytes(length, 'little'), initial_counter + block_count

    def _poly1305_clamp(self, r: bytes) -> int:
        """Decode the 16-byte Poly1305 ``r`` value and clear its reserved bits."""
        r_int = int.from_bytes(r, 'little')
        return r_int & 0x0ffffffc0ffffffc0ffffffc0fffffff

    def _poly1305_mac(self, message: bytes, key: bytes) -> bytes:
        """Compute the 16-byte Poly1305 tag of ``message``.

        :param message: data to authenticate (any length)
        :param key: 32-byte one-time key; first half is ``r``, second half ``s``
        :return: 16-byte tag
        :raises ValueError: if ``key`` is not 32 bytes
        """
        if len(key) != 32:
            raise ValueError("Poly1305密钥必须是32字节")

        r = self._poly1305_clamp(key[:16])
        s = int.from_bytes(key[16:], 'little')

        accumulator = 0
        for offset in range(0, len(message), 16):
            block = message[offset:offset + 16]
            # Append a 0x01 byte just above the block's highest byte. This
            # gives 2^128 for a full block and 2^(8*len) for a short final one.
            n = int.from_bytes(block, 'little') + (1 << (8 * len(block)))
            accumulator = (accumulator + n) * r % self._POLY1305_P

        tag = (accumulator + s) & ((1 << 128) - 1)
        return tag.to_bytes(16, 'little')

    def _validate_key_nonce(self, key: bytes, nonce: bytes) -> None:
        """Raise ValueError unless ``key`` is 32 bytes and ``nonce`` is 12 bytes."""
        if len(key) != 32:
            raise ValueError("密钥必须是32字节")
        if len(nonce) != 12:
            raise ValueError("nonce必须是12字节")

    def _mac_data(self, aad: bytes, ciphertext: bytes) -> bytes:
        """Build the Poly1305 input defined by RFC 8439.

        Layout: ``aad || pad16 || ciphertext || pad16 || len(aad) || len(ct)``,
        with both lengths as 64-bit little-endian integers.
        """
        return b''.join((
            aad, self._pad16(aad),
            ciphertext, self._pad16(ciphertext),
            struct.pack('<QQ', len(aad), len(ciphertext)),
        ))

    def encrypt(self, plaintext: bytes, key: bytes, nonce: bytes, aad: bytes = b'') -> tuple:
        """
        ChaCha20-Poly1305 AEAD encryption.

        :param plaintext: data to encrypt
        :param key: 32-byte key
        :param nonce: 12-byte nonce (must never repeat under the same key)
        :param aad: additional authenticated data (authenticated, not encrypted)
        :return: ``(ciphertext, tag)``, where ``tag`` is 16 bytes
        :raises ValueError: on a bad key/nonce length or an over-long plaintext
        """
        self._validate_key_nonce(key, nonce)

        # Block 0 yields the one-time Poly1305 key; the payload uses blocks 1...
        poly_key = self._chacha20_block(key, nonce, 0)[:32]
        ciphertext, self.counter = self._chacha20_xor(key, nonce, plaintext, 1)

        tag = self._poly1305_mac(self._mac_data(aad, ciphertext), poly_key)
        return ciphertext, tag

    def _pad16(self, data: bytes) -> bytes:
        """Return the zero bytes needed to pad ``data`` to a 16-byte boundary."""
        return b'\x00' * (-len(data) % 16)

    def decrypt(self, ciphertext: bytes, tag: bytes, key: bytes, nonce: bytes, aad: bytes = b'') -> bytes:
        """
        ChaCha20-Poly1305 AEAD decryption.

        The tag is verified before any plaintext is produced.

        :param ciphertext: encrypted data
        :param tag: 16-byte authentication tag
        :param key: 32-byte key
        :param nonce: 12-byte nonce used for encryption
        :param aad: additional authenticated data given at encryption time
        :return: the decrypted plaintext
        :raises ValueError: on bad key/nonce/tag lengths, or if authentication fails
        """
        self._validate_key_nonce(key, nonce)
        if len(tag) != 16:
            raise ValueError("认证标签必须是16字节")

        poly_key = self._chacha20_block(key, nonce, 0)[:32]
        expected_tag = self._poly1305_mac(self._mac_data(aad, ciphertext), poly_key)
        if not self._constant_time_compare(tag, expected_tag):
            raise ValueError("认证失败：数据可能被篡改")

        plaintext, self.counter = self._chacha20_xor(key, nonce, ciphertext, 1)
        return plaintext

    def _constant_time_compare(self, a: bytes, b: bytes) -> bool:
        """Compare two byte strings without leaking, via timing, where they differ.

        Delegates to ``hmac.compare_digest``, the stdlib primitive intended
        for this. It returns False for inputs of different lengths.
        """
        return hmac.compare_digest(a, b)

    def reset_counter(self, counter: int = 0) -> None:
        """
        Reset the ``counter`` attribute.

        :param counter: new counter value, default 0
        """
        self.counter = counter


def hex_to_bytes(hex_string: str) -> bytes:
    """Decode a hexadecimal string such as ``"0aff"`` into raw bytes.

    Delegates to ``bytes.fromhex``, so ASCII whitespace between byte pairs
    is tolerated and invalid input raises ValueError.
    """
    decoded = bytes.fromhex(hex_string)
    return decoded


def bytes_to_hex(data: bytes) -> str:
    """Encode ``data`` as a lowercase hexadecimal string with no separators."""
    hex_text = data.hex()
    return hex_text


def main():
    """Decrypt captured sample payloads with PyCryptodome's raw ChaCha20.

    NOTE: this routine does not exercise the ChaCha20Poly1305 class in this
    file. The samples use an 8-byte nonce, which that class rejects (it
    requires 12 bytes), and no Poly1305 tag is checked. A fresh cipher object
    is created for every sample, so each decryption starts at the beginning
    of the keystream. Each result is printed as hex.
    """
    # Sample keys and nonce supplied by the user.
    lobby_key = bytes.fromhex(
        "9ac488d72c2613e55c2e4d12125f5376555c159b78d2675b378259577f331c4e")
    game_key = bytes.fromhex(
        "402e9fd6462897e67150370a5036f192d9f7d694aba1f02d72e32d9e5332596c")
    nonce = bytes.fromhex("e190f6998fcfebd5")

    from Crypto.Cipher import ChaCha20

    # (key, ciphertext as hex), decrypted in this order.
    samples = (
        (lobby_key, "7C22A7B9931547E06032AC85EF4B7103DB20B8AE9D325AF87BE039595B246B81FBD04C977FE196058B5FF4BE948360D23E167D5D430CEF03D8F396C1233C26683C444BE76F8797F15AEF338A578278E5F4C8E4B7ECC1EA3C408DC2F0ABE432A38EC49959AE6B90C59B777CED0049EEF8F9ECFEB80B3C777B2A12B21782018DCD3DB5BB6E3746CF5EF604611667BF246FA2391377AE832589B52C22746818BAAFDBFF83D0D3832D1425209FF2ABDCE35088E1400B413199B7BF1EE56EA8525E9BE851447F3C8D2749BB370AAB52A86B8E15295FFCF345B107FD4FB581702212DFC199B1AA353EB4C8E5FBF2C3402776002C168164B7E9BE6AA33419A11E407F34B0910A41B36DA668011B552E1B6208257801B608963CD7C06CD32A964B129AB10B7F1FF0"),
        (lobby_key, "7CC897E4D9711EB3177FF7FAB32F0562B1C814E2"),
        (game_key, "ACC69BAB7E"),
        (game_key, "ACA99918030949C06DE528AFD528F365E503286F5349F18DF24042011276AF167AFF5931899099F540DB7FF53E816EFBFF504D5846CD75633088E66A7E2ECCB9FCBF8DE31269DA01673D29A88CA61FE35C4249867B43968D134E132626BD736716E202E16E1CA3FBD455F6F6F2891418F249EB45CA82AF38A20C3FEB459232E23C6E564E07C8235C442DB5D3C06835B5F34DFF1BD96DB16F9EF587B7E461771E7ED623709883EE2B467161C3FF523215D1F6763B402A5BB0E35B1888F4781584F467F7134B81F74ABA5CEF8DDED773550735756C9FAE579993A8491E7D06F5864AC575456490EF8B9D53D76F23C35D90C437B62F27B756BBF159E88C7A17F9A7A038DCB336BAFA92B9CD560D4C1F22BE6D5A61562DFEA2EA3C1309CC070A55A0151FE9AF"),
    )

    for sample_key, ciphertext_hex in samples:
        cipher = ChaCha20.new(key=sample_key, nonce=nonce)
        plaintext = cipher.decrypt(bytes.fromhex(ciphertext_hex))
        print("结果:", plaintext.hex())


if __name__ == "__main__":
    main()